{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Training a forward model for `Reacher-v4`\n",
    "\n",
    "This is a supplementary notebook showing how a forward model can be trained for the MuJoCo environment `Reacher-v4`.\n",
    "At the end, this notebook generates and saves a pickle file which stores the newly trained forward model.\n",
    "The generated pickle file can be used with the model predictive control example."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "Although not a dependency of EvoTorch, this notebook uses [skorch](https://github.com/skorch-dev/skorch) for the required supervised learning operations. `skorch` can be installed via:\n",
    "\n",
    "```bash\n",
    "pip install skorch\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "## Initial imports\n",
    "\n",
    "We begin our code with initial imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "from typing import Iterable\n",
    "import multiprocessing as mp\n",
    "import math\n",
    "from torch import nn\n",
    "from skorch import NeuralNetRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "## Declarations\n",
    "\n",
    "We declare the environment name below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "ENV_NAME = \"Reacher-v4\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "By default, we use all the available CPUs of the local computer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_PROCESSES = mp.cpu_count()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "We are going to collect data from this many episodes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPISODES = 20000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## Utilities for training\n",
    "\n",
    "Here, we define helper functions and utilities for the training of our model.\n",
    "\n",
    "We begin by defining the function $\\text{reacher\\_state}(\\cdot)$ which, given an observation from the reinforcement learning environment `Reacher-v4`, extracts and returns the state vector of the simulated robotic arm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reacher_state(observation: Iterable) -> Iterable:\n",
    "    observation = np.asarray(observation, dtype=\"float32\")\n",
    "    state = np.concatenate([observation[:4], observation[6:10]])\n",
    "    state[-2] += observation[4]\n",
    "    state[-1] += observation[5]\n",
    "    return state"
   ]
  },
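  {
   "cell_type": "markdown",
   "id": "11a",
   "metadata": {},
   "source": [
    "As a quick check of the indexing above, the sketch below applies the same extraction to a hand-made observation. It assumes the 11-element `Reacher-v4` observation layout (joint cosines and sines, target position, angular velocities, fingertip-minus-target offset, and a trailing zero). The helper name `extract_state` and the numbers are made up for illustration and do not describe a physically consistent pose.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def extract_state(observation):\n",
    "    # Same indexing as reacher_state, repeated so that this sketch is self-contained\n",
    "    observation = np.asarray(observation, dtype='float32')\n",
    "    state = np.concatenate([observation[:4], observation[6:10]])\n",
    "    # fingertip position = (fingertip - target) + target\n",
    "    state[-2] += observation[4]\n",
    "    state[-1] += observation[5]\n",
    "    return state\n",
    "\n",
    "\n",
    "# Illustrative observation: target at (0.5, -0.25), zero velocities,\n",
    "# fingertip-minus-target offset of (-0.25, 0.5)\n",
    "obs = [1.0, 1.0, 0.0, 0.0, 0.5, -0.25, 0.0, 0.0, -0.25, 0.5, 0.0]\n",
    "state = extract_state(obs)\n",
    "print(state.shape)  # (8,)\n",
    "print(state[-2:])   # [0.25 0.25]\n",
    "```\n",
    "\n",
    "The last two entries are the fingertip coordinates, which is why the rest of the notebook reads the arm's tip position from `state[-2]` and `state[-1]`."
   ]
  },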
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "We now define a function $\\text{collect\\_data}(\\cdot)$ which collects data from multiple episodes, number of these episodes being specified via the argument `num_episodes`.\n",
    "Within each episode, the data we collect is:\n",
    "\n",
    "- current state\n",
    "- action (uniformly sampled)\n",
    "- next state (i.e. the state obtained after applying the action)\n",
    "\n",
    "The forward model that we wish to train should be able to answer this question: _given the current state and the action, what is the prediction for the next state?_ Therefore, among the data we collect, the current states and the actions are categorized as the inputs, while the next states are categorized as the targets.\n",
    "The function $\\text{collect\\_data}(\\cdot)$ organizes its data into inputs and targets, and finally returns them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_data(num_episodes: int):\n",
    "    inputs = []\n",
    "    targets = []\n",
    "\n",
    "    env = gym.make(ENV_NAME)\n",
    "    for _ in range(num_episodes):\n",
    "        observation, _ = env.reset()\n",
    "\n",
    "        while True:\n",
    "            action = np.clip(np.asarray(env.action_space.sample(), dtype=\"float32\"), -1.0, 1.0)\n",
    "            state = reacher_state(observation)\n",
    "            \n",
    "            observation, reward, terminated, truncated, info = env.step(action)\n",
    "            done = terminated | truncated\n",
    "\n",
    "            next_state = reacher_state(observation)\n",
    "\n",
    "            current_input = np.concatenate([state, action])\n",
    "            current_target = next_state - state\n",
    "            \n",
    "            inputs.append(current_input)\n",
    "            targets.append(current_target)\n",
    "            \n",
    "            if done:\n",
    "                break\n",
    "    \n",
    "    return np.vstack(inputs), np.vstack(targets)"
   ]
  },
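  {
   "cell_type": "markdown",
   "id": "13a",
   "metadata": {},
   "source": [
    "To make the shape of one training sample concrete, here is a minimal sketch with made-up numbers (no environment involved). It mirrors how `collect_data` builds `current_input` and `current_target`, and shows that the next state can be recovered from the two.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up 8-element states and a 2-element action, for illustration only\n",
    "state = np.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0], dtype='float32')\n",
    "action = np.array([0.5, -0.5], dtype='float32')\n",
    "next_state = np.array([1.0, 1.0, 0.0, 0.0, 0.5, -0.5, 0.25, 0.0], dtype='float32')\n",
    "\n",
    "current_input = np.concatenate([state, action])  # 8 state values + 2 action values\n",
    "current_target = next_state - state              # the state change, not the raw next state\n",
    "\n",
    "# The next state is the state part of the input plus the target\n",
    "recovered = current_input[:8] + current_target\n",
    "print(current_input.shape, current_target.shape)  # (10,) (8,)\n",
    "print(np.array_equal(recovered, next_state))      # True\n",
    "```\n",
    "\n",
    "These shapes are what the network defined later expects: 10 inputs and 8 outputs."
   ]
  },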
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "The function below uses multiple CPUs of the local computer to collect data in parallel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_data_in_parallel(num_episodes: int):\n",
    "    n = math.ceil(num_episodes / NUM_PROCESSES)\n",
    "    \n",
    "    with mp.Pool(NUM_PROCESSES) as p:\n",
    "        collected_data = p.map(collect_data, [n for _ in range(NUM_PROCESSES)])\n",
    "    \n",
    "    all_inputs = []\n",
    "    all_targets = []\n",
    "    \n",
    "    for inp, target in collected_data:\n",
    "        all_inputs.append(inp)\n",
    "        all_targets.append(target)\n",
    "    \n",
    "    all_inputs = np.vstack(all_inputs)\n",
    "    all_targets = np.vstack(all_targets)\n",
    "    \n",
    "    return all_inputs, all_targets"
   ]
  },
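  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "The way `collect_data_in_parallel` divides the work can be illustrated with plain arithmetic. The sketch below assumes a hypothetical machine with 6 CPUs; on a real machine the value comes from `mp.cpu_count()`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_episodes = 20000\n",
    "num_processes = 6  # hypothetical CPU count, for illustration only\n",
    "\n",
    "# Every worker gets the same number of episodes, rounded up\n",
    "per_worker = math.ceil(num_episodes / num_processes)\n",
    "total = per_worker * num_processes\n",
    "print(per_worker, total)  # 3334 20004\n",
    "```\n",
    "\n",
    "Because of the rounding, the total can be slightly larger than the requested number of episodes, which is harmless when the goal is simply to gather enough training data."
   ]
  },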
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "To make the supervised learning procedure more efficient, we also introduce a normalizer.\n",
    "This normalizing function receives a batch (i.e. a collection) of vectors (where this batch can be the input data or the output data), and returns:\n",
    "\n",
    "- the normalized counterpart of the entire data\n",
    "- mean of the data\n",
    "- standard deviation of the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(x: np.ndarray) -> tuple:\n",
    "    mean = np.mean(x, axis=0).astype(\"float32\")\n",
    "    stdev = np.clip(np.std(x, axis=0).astype(\"float32\"), 1e-5, np.inf)\n",
    "    normalized = np.asarray((x - mean) / stdev, dtype=\"float32\")\n",
    "    return normalized, mean, stdev"
   ]
  },
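  {
   "cell_type": "markdown",
   "id": "17a",
   "metadata": {},
   "source": [
    "The returned mean and standard deviation are kept so that normalized values can later be mapped back to the original scale. The sketch below repeats `normalize` so that it is self-contained, and applies it to synthetic data (not `Reacher-v4` data) that includes a constant column to show the effect of the clipping.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize(x):\n",
    "    # Same computation as the notebook's normalize function\n",
    "    mean = np.mean(x, axis=0).astype('float32')\n",
    "    stdev = np.clip(np.std(x, axis=0).astype('float32'), 1e-5, np.inf)\n",
    "    normalized = np.asarray((x - mean) / stdev, dtype='float32')\n",
    "    return normalized, mean, stdev\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.normal(loc=3.0, scale=2.0, size=(1000, 4)).astype('float32')\n",
    "x[:, 3] = 1.0  # a constant column has zero standard deviation\n",
    "\n",
    "normalized, mean, stdev = normalize(x)\n",
    "\n",
    "# De-normalization is the inverse operation\n",
    "restored = normalized * stdev + mean\n",
    "print(np.allclose(restored, x, atol=1e-4))\n",
    "print(stdev[3], normalized[:, 3].max())\n",
    "```\n",
    "\n",
    "Without the clipping, the constant column would lead to a division by zero. With it, that column simply normalizes to zeros."
   ]
  },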
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "We are now ready to collect our data and store them.\n",
    "\n",
    "The following class (not to be instantiated) serves as a namespace where all our collected data and their stats (i.e. means and standard deviations) are stored.\n",
    "The rest of this notebook will refer to this namespace when training, saving, and testing the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "class data:\n",
    "    inputs = None\n",
    "    targets = None\n",
    "\n",
    "    input_mean = None\n",
    "    input_stdev = None\n",
    "\n",
    "    target_mean = None\n",
    "    target_stdev = None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Below, we collect the data and their stats, and store them in the `data` namespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.inputs, data.targets = collect_data_in_parallel(NUM_EPISODES)\n",
    "\n",
    "data.inputs, data.input_mean, data.input_stdev = normalize(data.inputs)\n",
    "data.targets, data.target_mean, data.target_stdev = normalize(data.targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.inputs.shape, data.targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.input_mean.shape, data.input_stdev.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.target_mean.shape, data.target_stdev.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25",
   "metadata": {},
   "source": [
    "We declare the following architecture for our neural network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "    nn.Linear(10, 64),\n",
    "    nn.Tanh(),\n",
    "    nn.LayerNorm(64),\n",
    "    nn.Linear(64, 64),\n",
    "    nn.Tanh(),\n",
    "    nn.LayerNorm(64),\n",
    "    nn.Linear(64, 8),\n",
    ")\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "Declare a regression problem and set the values of the hyperparameters to be used for the training procedure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
    "regressor = NeuralNetRegressor(\n",
    "    model,\n",
    "    max_epochs=50,\n",
    "    lr=0.0001,\n",
    "    optimizer=torch.optim.Adam,\n",
    "    iterator_train__shuffle=True,\n",
    "    batch_size=500,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "regressor.fit(data.inputs, data.targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "At this point, we should have a trained model.\n",
    "\n",
    "To test this trained model, we define the convenience function below which receives the current state and an action, and with the help of the trained model, returns the prediction for the next state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def use_net(state: Iterable, action: Iterable) -> Iterable:\n",
    "    input_mean = torch.as_tensor(data.input_mean, dtype=torch.float32)\n",
    "    input_stdev = torch.as_tensor(data.input_stdev, dtype=torch.float32)\n",
    "    target_mean = torch.as_tensor(data.target_mean, dtype=torch.float32)\n",
    "    target_stdev = torch.as_tensor(data.target_stdev, dtype=torch.float32)\n",
    "    \n",
    "    state = torch.as_tensor(state, dtype=torch.float32)\n",
    "    action = torch.clamp(torch.as_tensor(action, dtype=torch.float32), -1.0, 1.0)\n",
    "    \n",
    "    x = torch.cat([state, action])    \n",
    "    x = (x - input_mean) / input_stdev\n",
    "    y = model(x)\n",
    "    y = (y * target_stdev) + target_mean\n",
    "    result = (y + state).numpy()\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33",
   "metadata": {},
   "source": [
    "To compare the predictions of our model against the actual states, we instantiate a `Reacher-v4` environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(ENV_NAME)\n",
    "env"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "In the code below, we have a loop which feeds both the actual `Reacher-v4` environment and our trained predictor the same actions.\n",
    "During the execution of this loop, the x and y coordinates of the robotic arm's tip, reported both by the actual environment and by the trained predictor are collected.\n",
    "At the end, the collected x and y coordinates are plotted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36",
   "metadata": {},
   "outputs": [],
   "source": [
    "observation, _ = env.reset()\n",
    "observation = np.asarray(observation, dtype=\"float32\")\n",
    "\n",
    "actual_state = reacher_state(observation)\n",
    "pred_state = actual_state.copy()\n",
    "\n",
    "class actual:\n",
    "    x = []\n",
    "    y = []\n",
    "\n",
    "actual.x.append(actual_state[-2])\n",
    "actual.y.append(actual_state[-1])    \n",
    "\n",
    "class predicted:\n",
    "    x = []\n",
    "    y = []\n",
    "\n",
    "predicted.x.append(pred_state[-2])\n",
    "predicted.y.append(pred_state[-1])    \n",
    "\n",
    "while True:\n",
    "    action = np.asarray(env.action_space.sample(), dtype=\"float32\")\n",
    "    \n",
    "    observation, reward, terminated, truncated, info = env.step(action)\n",
    "    done = terminated | truncated\n",
    "\n",
    "    actual_state = reacher_state(observation)\n",
    "    \n",
    "    pred_state = use_net(pred_state, action)\n",
    "\n",
    "    actual.x.append(actual_state[-2])\n",
    "    actual.y.append(actual_state[-1])    \n",
    "\n",
    "    predicted.x.append(pred_state[-2])\n",
    "    predicted.y.append(pred_state[-1])    \n",
    "\n",
    "    if done:\n",
    "        break"
   ]
  },
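  {
   "cell_type": "markdown",
   "id": "36a",
   "metadata": {},
   "source": [
    "Besides plotting, the gap between the two trajectories can be summarized numerically. The sketch below is one possible way to do this, using short made-up lists as stand-ins for `actual.x`, `actual.y`, `predicted.x`, and `predicted.y`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up stand-ins for the collected coordinates\n",
    "actual_x = [0.0, 0.1, 0.2, 0.3]\n",
    "actual_y = [0.0, 0.0, 0.1, 0.1]\n",
    "predicted_x = [0.0, 0.1, 0.2, 0.6]\n",
    "predicted_y = [0.0, 0.0, 0.1, 0.5]\n",
    "\n",
    "actual_xy = np.stack([actual_x, actual_y], axis=1)\n",
    "predicted_xy = np.stack([predicted_x, predicted_y], axis=1)\n",
    "\n",
    "# Euclidean distance between the actual and predicted fingertip at each step\n",
    "step_error = np.linalg.norm(actual_xy - predicted_xy, axis=1)\n",
    "print(step_error)\n",
    "print(step_error.mean())\n",
    "```\n",
    "\n",
    "With the real trajectories, a per-step error that grows along the episode would indicate that prediction errors are compounding during the rollout."
   ]
  },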
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(actual.x)\n",
    "plt.plot(predicted.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(actual.y)\n",
    "plt.plot(predicted.y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "Below, we save our trained model.\n",
    "This trained model can be used by the `Reacher-v4` MPC example notebook, if copied next to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"reacher_model.pickle\", \"wb\") as f:\n",
    "    pickle.dump(\n",
    "        {\n",
    "            \"model\": model,\n",
    "            \"input_mean\": data.input_mean,\n",
    "            \"input_stdev\": data.input_stdev,\n",
    "            \"target_mean\": data.target_mean,\n",
    "            \"target_stdev\": data.target_stdev,\n",
    "        },\n",
    "        f\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
